%% @doc The database access library for the SPEWF board application.
%% Database primitive operations such as create, update and read
%% are supported, as well as convenience operations to minimize
%% database round-trips.
%%
%% @type guid() = {integer(), integer(), integer(), integer()}
%%
%% @todo Delete records. Unset fields.
%%
%% @end
-module(bd_db).

%% Basic database operations
-export([new_user/1, new_post/1,
         update_user/1, update_post/1,
         find_user/1, find_post/1,
         find_users/1, find_posts/1]).

%% Convenience operations
-export([authenticate/2,
         thread/1,
         watch/2,
         watching/1,
         latest_posts/1,
         nice/1,
         id2s/1,
         s2id/1,
         ascending/1,
         descending/1]).

%% Maintenance operations
-export([
         fix_friends/0,
         create_tables/1,
         load_test_data/1,
         drop_tables/0,
         posts/0, rawposts/0,
         users/0, rawusers/0]).

-include_lib("stdlib/include/qlc.hrl").
-include_lib("stdlib/include/ms_transform.hrl").

-include("board.hrl").

-include_lib("eunit/include/eunit.hrl").

%% Run X inside an mnesia transaction and return its result, crashing
%% with badmatch if the transaction aborts. X is wrapped in a fun so it
%% is evaluated inside the transaction. The work is delegated to
%% atomic/1 instead of binding a variable in the expansion: Erlang
%% macros are not hygienic, so a variable bound by the macro body would
%% clash if ?atomic were used twice in one function.
-define(atomic(X), atomic(fun() -> X end)).

atomic(Fun) ->
   {atomic, Result} = mnesia:transaction(Fun),
   Result.

%% @spec make_guid() -> guid()
%% @doc Create a new globally unique identifier from 128 random bits,
%% returned as four unsigned 32-bit integers. This does not use
%% any partitioning schemes, it's purely random. GUIDs make better
%% keys than sequences because you don't need to worry about
%% concurrent access to the sequence, and you don't have the tree
%% balancing issues of using a monotonically increasing sequence
%% as a key.
make_guid() ->
   %% crypto:rand_bytes/1 was deprecated and later removed from OTP;
   %% strong_rand_bytes/1 is the supported replacement.
   <<N1:32/integer, N2:32/integer, N3:32/integer, N4:32/integer>> =
      crypto:strong_rand_bytes(16),
   {N1, N2, N3, N4}.

%% These ...1 functions are the "non-transactional" cores of
%% the wrapper functions which call transactions. The ...1 functions
%% call each other, and are all wrapped up in a big transaction
%% by the interface functions.

%% Insert U under a fresh GUID unless its email or nickname is already
%% taken. Both email and nickname must be set (otherwise
%% function_clause). Returns {ok, Id} or
%% {error, email_in_use | nickname_in_use}.
new_user1(U = #user{email = E, nickname = N}) when E /= undefined,
                                                   N /= undefined ->
   case find_user1(U) of
      false ->
         NU = U#user{id = make_guid()},
         ok = mnesia:write(NU),
         {ok, NU#user.id};
      #user{email = E} -> {error, email_in_use};
      #user{nickname = N} -> {error, nickname_in_use};
      %% The email belongs to one existing user and the nickname to
      %% another; report the email clash rather than crash.
      not_unique -> {error, email_in_use}
   end.

%% @spec new_user(User) -> {ok, guid()} + {error, Reason}
%%    User = #user{}
%%    Reason = any()
%% @doc Create a new user from the given record (email and nickname
%% must be set) inside its own transaction. Returns the new user's Id.
new_user(NewUser) ->
   ?atomic(new_user1(NewUser)).

%% Fetch the first post whose id equals Id, or false if none, in its
%% own transaction. NOTE(review): not exported and not called anywhere
%% in this module — confirm it is still needed.
get_post(Id) ->
   Query = qlc:q([Post || Post <- mnesia:table(post),
                          Post#post.id =:= Id]),
   txq1(Query).

%% Evaluate the query handle in its own transaction and return its first
%% answer, or false when it has none.
txq1(Query) ->
   ?atomic(q1(Query)).

%% Evaluate the query (in the caller's transaction) and return the first
%% answer, or false for an empty answer list.
q1(Query) ->
   case qlc:e(Query) of
      [] -> false;
      [First | _] -> First
   end.

%% Evaluate the query handle in its own transaction and return all answers.
txqall(Query) ->
   ?atomic(qlc:e(Query)).

%% Run the query in a transaction and convert every answer with nice/1.
niceq(Query) ->
   Rows = txqall(Query),
   lists:map(fun (Row) -> nice(Row) end, Rows).

%% @spec find_users(UserSpec) -> [User]
%%    UserSpec = #user{} + guid() + string()
%%    User = #user{}
%% @doc Find the matching user records and return them in a (potentially
%% empty) list, inside a transaction. If just a string is given, both the
%% user nickname and email fields are searched for it.
find_users(Spec) ->
   ?atomic(find_users1(Spec)).

%% Core of find_users/1 (runs in the caller's transaction).
%% - A spec with both email and nickname set matches users having that
%%   email OR that nickname (other fields of the spec are ignored).
%% - A guid is looked up as the user id.
%% - Any other non-record term (normally a string) is tried as both
%%   email and nickname.
%% - Otherwise the unset fields of the record become wildcards.
find_users1(#user{email = Email, nickname = Nick}) when Email /= undefined,
                                                       Nick /= undefined ->
   users_by_email_or_nickname(Email, Nick);
find_users1(Guid = {First, _, _, _}) when is_integer(First) ->
   find_users1(#user{id = Guid});
find_users1(Text) when not is_record(Text, user) ->
   find_users1(#user{email = Text, nickname = Text});
find_users1(Spec) ->
   Pattern = set_defaults(Spec, #user{_ = '_'}),
   mnesia:match_object(Pattern).

%% All users whose email is Email or whose nickname is Nick.
users_by_email_or_nickname(Email, Nick) ->
   Query = qlc:q([User || User <- mnesia:table(user),
                          User#user.email =:= Email orelse
                             User#user.nickname =:= Nick]),
   qlc:e(Query).

%% @spec find_user(UserSpec) -> User + not_unique + false
%%    UserSpec = #user{} + guid() + string()
%% @doc Find exactly one matching user record, inside a transaction.
%% The atom ``not_unique'' is returned if there is more than one matching
%% user. The atom ``false'' is returned if there is no matching user.
find_user(Spec) ->
   ?atomic(find_user1(Spec)).

%% Core of find_user/1: the single match, false for none, not_unique
%% for anything else.
find_user1(Spec) ->
   case find_users1(Spec) of
      [] -> false;
      [Single] -> Single;
      _ -> not_unique
   end.

%% Core of new_post/1: stamp the post with a fresh GUID and the current
%% UTC time, normalize string fields to database form and write it.
new_post1(Post) ->
   Id = make_guid(),
   Stamped = Post#post{id = Id, timestamp = calendar:universal_time()},
   ok = mnesia:write(normalize_post1(Stamped)),
   {ok, Id}.

%% Transactional wrapper around normalize_user1/1.
%% NOTE(review): not exported and not called in this module.
normalize_user(UserRef) ->
   ?atomic(normalize_user1(UserRef)).

%% @spec id2s(Id::guid()) -> string()
%% @doc Return a (url safe) string representation of the GUID: the
%% external term format, base64-encoded with the url-safe alphabet and
%% without padding. Any other term is returned unchanged.
id2s(Guid = {First, _, _, _}) when is_integer(First) ->
   Encoded = base64:encode_to_string(term_to_binary(Guid)),
   urlsafe(Encoded);
id2s(Other) ->
   Other.

%% @spec s2id(string()) -> guid()
%% @doc Return the GUID given the string generated by {@link id2s/1}.
%% These strings are presumably client-supplied (e.g. taken from URLs), so
%% they are decoded with binary_to_term/2's `safe' option, which
%% refuses to create new atoms or external funs from the input.
s2id(Id) when is_list(Id) ->
   binary_to_term(base64:decode(urlunsafe(Id)), [safe]).

%% Resolve a user reference to a user id: a guid is returned as is, a
%% string is looked up as nickname/email (undefined unless exactly one
%% user matches), and any other term is returned unchanged.
normalize_user1(Guid = {First, _, _, _}) when is_integer(First) ->
   Guid;
normalize_user1(Name) when is_list(Name) ->
   case find_user1(#user{nickname = Name, email = Name}) of
      #user{id = Uid} -> Uid;
      _NotFoundOrAmbiguous -> undefined
   end;
normalize_user1(Other) ->
   Other.

%% Undo urlsafe/1: restore the `=' padding to a multiple of four
%% characters and map the url-safe alphabet (`-', `_') back to the
%% standard base64 one (`+', `/').
urlunsafe(Encoded) ->
   Padding = lists:duplicate((4 - length(Encoded) rem 4) rem 4, $=),
   [case Ch of
       $- -> $+;
       $_ -> $/;
       _ -> Ch
    end || Ch <- Encoded ++ Padding].
                             
%% Make a base64 string url safe: drop trailing `=' padding and map
%% `+' to `-' and `/' to `_'.
urlsafe(Encoded) ->
   Trimmed = lists:reverse(
               lists:dropwhile(fun (Ch) -> Ch =:= $= end,
                               lists:reverse(Encoded))),
   [case Ch of
       $+ -> $-;
       $/ -> $_;
       _ -> Ch
    end || Ch <- Trimmed].

%% @spec nice(Record) -> User + Post
%%    Record = #user{} + #post{}
%%    User = #user{}
%%    Post = #post{}
%% @doc Return a "nice" form of the given record. A "nice" record
%% is not suitable for database operations: the Ids are replaced
%% with either a user nickname or a string Id (for posts). For users,
%% friends become nicknames and watched posts become their summaries;
%% an unset (undefined) friends or watching list stays undefined.
nice(P = #post{id = Id, author = Uid, parent = Parid}) ->
   User = find_user(Uid),
   P#post{id = id2s(Id),
          parent = id2s(Parid),
          author = User#user.nickname};
nice(U = #user{id = Uid, friends = Friends, watching = Watch}) ->
   NiceWatch = map_defined(fun (P) ->
                                 Post = find_post(#post{id = P}),
                                 Post#post.summary
                           end, Watch),
   %% Friends get the same undefined handling as watching; an unset
   %% friends list used to crash lists:map/2.
   NiceFriends = map_defined(fun (F) ->
                                   User = find_user(F),
                                   User#user.nickname
                             end, Friends),
   U#user{id = id2s(Uid),
          friends = NiceFriends,
          watching = NiceWatch}.

%% Map F over List, passing an unset (undefined) list through unchanged.
map_defined(_F, undefined) -> undefined;
map_defined(F, List) -> lists:map(F, List).

%% Convert a post from "nice" form back to database form. A string
%% author (nickname or email) becomes a user id (undefined if it does
%% not resolve); string parent and id fields are decoded with s2id/1.
%% Each clause rewrites one field and recurses until none of author,
%% parent or id is a list.
normalize_post1(Post = #post{author = Author}) when is_list(Author) ->
   Resolved = normalize_user1(Author),
   normalize_post1(Post#post{author = Resolved});
normalize_post1(Post = #post{parent = ParentStr}) when is_list(ParentStr) ->
   normalize_post1(Post#post{parent = s2id(ParentStr)});
normalize_post1(Post = #post{id = IdStr}) when is_list(IdStr) ->
   normalize_post1(Post#post{id = s2id(IdStr)});
normalize_post1(Post = #post{}) ->
   Post.

%% @spec new_post(Post::#post{}) -> {ok, guid()} + {error, Reason}
%%    Reason = any()
%% @doc Create a new post in its own transaction. The post is given
%% the current UTC timestamp and a new GUID.
new_post(Post) ->
   ?atomic(new_post1(Post)).

%% @spec update_post(PostUpdate) -> ok + Error
%%    Error = any()
%%    PostUpdate = #post{}
%% @doc Update the given post in its own transaction. The PostUpdate must
%% contain the Post Id field; its other non-default fields are the updates.
update_post(Update) ->
   ?atomic(update_post1(Update)).

%% Core of update_post/1: merge the set fields of P over the stored post
%% with the same id and write the result. P may be in "nice" form
%% (string id/parent, author nickname); it is normalized first so the
%% write replaces the stored record instead of creating one keyed by a
%% string. The id must be set (otherwise function_clause), as in
%% update_user1/1.
update_post1(P = #post{id = Id}) when Id /= undefined ->
   Update = normalize_post1(P),
   case find_post1(#post{id = Update#post.id}) of
      false ->
         {error, {does_not_exist, Id}};
      not_unique ->
         {error, {not_unique, Id}};
      Old ->
         mnesia:write(set_defaults(Update, Old))
   end.
      
%% @private
%% @doc Every post record exactly as stored (dirty read, no transaction).
rawposts() ->
   mnesia:dirty_match_object(post, #post{_ = '_'}).

%% @private
%% @doc Every user record exactly as stored (dirty read, no transaction).
rawusers() ->
   mnesia:dirty_match_object(user, #user{_ = '_'}).

%% @private
%% @doc Every post, converted with nice/1.
posts() ->
   AllPosts = qlc:q([Post || Post <- mnesia:table(post)]),
   niceq(AllPosts).
                            
%% @private
%% @doc Every user, converted with nice/1.
users() ->
   AllUsers = qlc:q([User || User <- mnesia:table(user)]),
   niceq(AllUsers).

%% Fill every `undefined' field of Rec with the corresponding field of
%% Def. Both must be tuples of the same size tagged with the same atom
%% (records of one type); otherwise function_clause. With
%% Def = #rec{_ = '_'} this turns a partial record into an mnesia match
%% pattern; with Def = a stored record it merges an update over it.
set_defaults(Rec, Def) when element(1, Rec) =:= element(1, Def),
                            is_atom(element(1, Rec)),
                            tuple_size(Rec) =:= tuple_size(Def) ->
   Merged = lists:zipwith(fun (undefined, Default) -> Default;
                              (Value, _Default) -> Value
                          end,
                          tuple_to_list(Rec), tuple_to_list(Def)),
   list_to_tuple(Merged).

%% @private
%% @spec fix_friends() -> {atomic, Changed} + {aborted, Reason}
%%    Changed = integer()
%% @doc Walks through the user table, resolving nicknames (or emails) in
%% the "friends" fields to user ids. Friends that do not resolve to
%% exactly one user are dropped. Users whose friends list is unset are
%% skipped. Returns the number of user records rewritten.
fix_friends() ->
   Normalize = fun (FU) ->
                     case find_user1(FU) of
                        #user{id = Id} -> Id;
                        %% false (no match) or not_unique: unresolvable.
                        _ -> undefined
                     end
               end,
   Tx = fun () ->
              mnesia:foldl(
                fun (#user{friends = undefined}, N) ->
                      N;
                    (U = #user{friends = Friends}, N) ->
                      %% Filter every unresolved friend, not only the
                      %% leading ones (lists:dropwhile/2 did the latter).
                      Normal = [Id || Id <- lists:map(Normalize, Friends),
                                      Id =/= undefined],
                      case Normal of
                         Friends -> N;
                         _ ->
                            ok = mnesia:write(U#user{friends = Normal}),
                            N + 1
                      end
                end, 0, user)
        end,
   mnesia:transaction(Tx).

%% @spec find_post(PostSpec) -> Post::#post{} + not_unique + false
%%    PostSpec = #post{} + guid()
%% @doc Find exactly one post identified by the PostSpec, inside a
%% transaction: ``false'' if none matches, ``not_unique'' if several do.
find_post(Spec) ->
   ?atomic(find_post1(Spec)).

%% Core of find_posts/1. Accepts a guid, or a #post{} spec whose unset
%% fields are wildcards. A string author is resolved to a user id and a
%% string id is decoded with s2id/1; the remaining fields of the spec
%% are kept (previously both clauses dropped them). An author string
%% that does not resolve to a user matches no posts; previously it
%% became `undefined', i.e. a wildcard matching every post.
find_posts1(Id = {N, _, _, _}) when is_integer(N) ->
   find_posts1(#post{id = Id});
find_posts1(P = #post{author = Author}) when is_list(Author) ->
   case normalize_user1(Author) of
      undefined -> [];
      Uid -> find_posts1(P#post{author = Uid})
   end;
find_posts1(P = #post{id = Sid}) when is_list(Sid) ->
   find_posts1(P#post{id = s2id(Sid)});
find_posts1(P) ->
   Match = set_defaults(P, #post{_ = '_'}),
   mnesia:match_object(Match).

%% Core of find_post/1: the single matching post, false for none,
%% not_unique for anything else.
find_post1(Spec) ->
   case find_posts1(Spec) of
      [] -> false;
      [Post] -> Post;
      _ -> not_unique
   end.

%% @spec find_posts(PostSpec) -> [Post::#post{}]
%%    PostSpec = #post{} + guid()
%% @doc Find the matching posts identified by the PostSpec, inside a
%% transaction. The PostSpec may contain a string for the ``author''
%% field; if so, the author is looked up and its user id used instead.
find_posts(Spec) ->
   ?atomic(find_posts1(Spec)).

%% Fold step used by latest_posts/1 to keep the Max newest posts.
%% Accumulator: {Earliest, Count, Max, RSet, Results} where
%%   Earliest - timestamp of the oldest kept post (undefined until the
%%              first post is seen),
%%   Count    - number of kept posts,
%%   Max      - number of posts wanted,
%%   RSet     - ordset of {Timestamp, Id} for the kept posts, oldest first,
%%   Results  - the kept #post{} records, in no particular order.
build_latest_list(_Post, Acc = {_, _, Max, _, _}) when Max =< 0 ->
   %% No posts wanted; without this clause the first post was always kept.
   Acc;
build_latest_list(Post = #post{timestamp = Timestamp, id = Id},
                  {undefined, _, Max, _, _}) ->
   {Timestamp, 1, Max, [{Timestamp, Id}], [Post]};
build_latest_list(Post = #post{timestamp = Timestamp, id = Id},
                  {Earliest, Count, Max, RSet, Results}) when Count >= Max,
                                                         Timestamp > Earliest ->
   %% Full, and this post is newer than the oldest kept one: evict that one.
   [{_, Rm}|ShortRset] = RSet,
   NewRSet = ordsets:add_element({Timestamp, Id}, ShortRset),
   NewEarliest = element(1, hd(NewRSet)),
   {NewEarliest,
    Count,
    Max,
    NewRSet,
    [Post|lists:keydelete(Rm, #post.id, Results)]};
build_latest_list(_Post, Acc = {_, Count, Max, _, _}) when Count >= Max ->
   %% Full, and this post is not newer than the oldest kept one: skip it.
   Acc;
build_latest_list(Post = #post{timestamp = Timestamp, id = Id},
                  {_Earliest, Count, Max, RSet, Results}) ->
   NewRSet = ordsets:add_element({Timestamp, Id}, RSet),
   NewEarliest = element(1, hd(NewRSet)),
   {NewEarliest,
    Count + 1,
    Max,
    NewRSet,
    [Post|Results]}.

%% Turn a latest_posts/1 fold accumulator into a newest-first list of
%% posts, ordering by RSet (an oldest-first ordset of {Timestamp, Id}).
%% An Id with no record in Results yields the atom internal_inconsistency.
order_latest_posts({_, _, _, RSet, Results}) ->
   lists:map(fun ({_, Id}) ->
                   %% Look up by the id field's real position rather than
                   %% a hard-coded tuple index.
                   case lists:keyfind(Id, #post.id, Results) of
                      false -> internal_inconsistency;
                      Post -> Post
                   end
             end, lists:reverse(RSet)).
                
%% @spec latest_posts(Max::integer()) -> [Post]
%%    Post = #post{}
%% @doc Retrieve the latest Max posts by folding over the whole post
%% table in one transaction. The posts are returned in most-recent-first
%% ("descending") order.
latest_posts(Max) ->
   Initial = {undefined, 0, Max, [], []},
   Acc = ?atomic(mnesia:foldl(fun build_latest_list/2, Initial, post)),
   order_latest_posts(Acc).

%% @spec thread(PostSpec) -> Thread + not_unique + false
%%    PostSpec = #post{} + guid()
%%    Thread = {Post, [Thread]}
%%    Post = #post{}
%% @doc Retrieve, in one transaction, the tree of posts representing the
%% thread of discussion rooted at the post identified by the PostSpec.
thread(Spec) ->
   ?atomic(thread1(Spec)).

%% Core of thread/1: build the {Post, Children} tree rooted at the post
%% matching the spec (a #post{} or, as documented, a guid), or return
%% false / not_unique when the spec does not match exactly one post.
thread1(Id = {N, _, _, _}) when is_integer(N) ->
   thread1(#post{id = Id});
thread1(Pq = #post{}) ->
   case find_post1(Pq) of
      P = #post{} -> thread_tree(P);
      Other -> Other
   end.

%% Tree below an already-fetched post; each level is oldest first.
%% Recursing here (instead of through the transactional thread/1) keeps
%% the walk in the caller's transaction and avoids re-fetching children.
thread_tree(P = #post{id = Id}) ->
   Children = ascending(find_posts1(#post{parent = Id})),
   {P, [thread_tree(Child) || Child <- Children]}.


%% @spec ascending([Post]) -> [Post]
%%    Post = #post{}
%% @doc Sort posts in oldest-first order. Posts with equal timestamps
%% keep their relative order.
ascending(Posts) ->
   %% lists:sort/2 expects a "less than or equal" predicate; with a
   %% strict `<' the relative order of posts sharing a timestamp (which
   %% has one-second resolution) is not preserved.
   lists:sort(fun (A, B) ->
                    A#post.timestamp =< B#post.timestamp
              end, Posts).

%% @spec descending([Post]) -> [Post]
%%     Post = #post{}
%% @doc Sort posts in newest-first order (by the timestamp field).
descending(Posts) ->
   NewerFirst = fun (A, B) -> B#post.timestamp =< A#post.timestamp end,
   lists:sort(NewerFirst, Posts).

%% @spec watching(User) -> [Post]
%%    User = #user{} + guid() + string()
%% @doc Returns, from one transaction, the list of posts the user is
%% watching, in "descending" (newest-first) order.
watching(UserRef) ->
   ?atomic(watching1(UserRef)).

%% Core of watching/1. U may be a #user{} spec, a guid or a string.
%% The user must resolve to exactly one record (otherwise badmatch).
%% Watched ids that no longer resolve to exactly one post are skipped;
%% previously they appeared as `false' or crashed descending/1.
watching1(U) when not is_record(U, user) ->
   watching1(#user{id = normalize_user1(U)});
watching1(U = #user{}) ->
   #user{watching = Watching} = find_user1(U),
   case Watching of
      undefined -> [];
      _ ->
         Posts = [P || P <- lists:map(fun find_post1/1, Watching),
                       is_record(P, post)],
         descending(Posts)
   end.

%% @spec watch(User, Post) -> ok + Error
%%    User = #user{} + guid() + string()
%%    Post = #post{} + guid()
%%    Error = any()
%% @doc Update the user record, in one transaction, to reflect the
%% user's watching of the specified post.
watch(UserRef, PostRef) ->
   ?atomic(watch1(UserRef, PostRef)).

%% Core of watch/1: add post P to user U's watch list (kept as an
%% ordset of post ids). Returns ok, or {error, {no_such_user, U}} /
%% {error, {no_such_post, P}} when either does not resolve to exactly
%% one record.
watch1(U, P) ->
   case find_user1(U) of
      #user{id = Uid, watching = Old} ->
         case find_post1(P) of
            #post{id = Pid} ->
               %% update_user1/1 stays in this transaction; the wrapper
               %% update_user/1 would open a nested one.
               update_user1(#user{id = Uid,
                                  watching = add_watched(Pid, Old)});
            _ ->
               {error, {no_such_post, P}}
         end;
      _ ->
         {error, {no_such_user, U}}
   end.

%% Add Pid to a watch list, treating an unset (undefined) list as empty.
add_watched(Pid, undefined) -> [Pid];
add_watched(Pid, Old) when is_list(Old) -> ordsets:add_element(Pid, Old).

%% @spec authenticate(Username, Password) -> {ok, guid()} + Error
%%    Username = string()
%%    Password = string()
%%    Error = {wrong_password, guid()} + {nonexistent_user, string()}
%% @doc Check the username (nickname or email) and password against the
%% user records in the database. A username that matches more than one
%% user is reported as nonexistent, and an undefined password never
%% authenticates.
%%
%% WARNING: This function distinguishes between an incorrect password
%% and a nonexistent username. It is considered bad security practice
%% to expose this distinction to the user (and therefore attacker).
%% The distinction is made in this function for applications such
%% as the generation of an event that locks accounts when too many
%% password failures are encountered (though that's bad security
%% practice too).
%%
%% NOTE(review): the given password is compared directly with the
%% stored password field, i.e. passwords appear to be stored in plain
%% text — consider storing and comparing hashes.
authenticate(Username, Password) ->
   case find_user(Username) of
      #user{id = Uid, password = Password} when Password =/= undefined ->
         {ok, Uid};
      #user{id = Uid} ->
         {wrong_password, Uid};
      %% false (no match) or not_unique (e.g. the name is one user's
      %% email and another user's nickname): refuse instead of crashing.
      _ ->
         {nonexistent_user, Username}
   end.

%% @spec update_user(UserUpdate) -> ok + Error
%%    Error = any()
%%    UserUpdate = #user{}
%% @doc Update, in one transaction, the user record identified by the
%% UserUpdate specification. It must contain the user id: all other
%% non-default fields are used to update the database record.
update_user(Update) ->
   ?atomic(update_user1(Update)).

%% Core of update_user/1: merge the set fields of the update over the
%% stored user with the same id and write it. The id must be set
%% (otherwise function_clause).
update_user1(Update = #user{id = Uid}) when Uid /= undefined ->
   case find_users1(#user{id = Uid}) of
      [Current] ->
         Merged = set_defaults(Update, Current),
         mnesia:write(Merged);
      [] ->
         {error, no_such_user};
      Many when is_list(Many) ->
         {error, not_unique};
      Other ->
         {error, Other}
   end.

%% @spec create_tables(Nodes) -> [Result]
%%    Result = {atomic, ok} + Error
%%    Error = any()
%%    Nodes = local + [node()]
%% @doc Create the application tables with disc_copies on the specified
%% nodes, plus secondary indexes on post parent and user nickname/email.
%% The five results are returned in the order the operations are performed.
create_tables(local) ->
   create_tables([node()]);
create_tables(Nodes) ->
   Opts = fun (Attrs) -> [{attributes, Attrs}, {disc_copies, Nodes}] end,
   PostTable = mnesia:create_table(post, Opts(record_info(fields, post))),
   PostParentIdx = mnesia:add_table_index(post, parent),
   UserTable = mnesia:create_table(user, Opts(record_info(fields, user))),
   UserNickIdx = mnesia:add_table_index(user, nickname),
   UserEmailIdx = mnesia:add_table_index(user, email),
   [PostTable, PostParentIdx, UserTable, UserNickIdx, UserEmailIdx].

%% Block the calling process for Millis milliseconds, then return ok.
sleep(Millis) ->
   timer:sleep(Millis).

%% @spec load_test_data(File::filename()) -> [Result]
%%    Result = {users, integer()} + {posts, integer()} + {rejects, [Reject]}
%%    Reject = {Error, Record} + term()
%%    Record = #user{} + #post{}
%% @doc Load test data from a file containing terms. The terms are
%% incomplete records and other update specifiers:
%%   #user{}  - created with new_user/1;
%%   #post{}  - created with new_post/1;
%%   {new_post, _, _, Author, {summary, ParentSummary}, Summary, Text}
%%            - a reply to the post with summary ParentSummary.
%% Any other term, or a term that fails to load, is collected as a
%% reject. Each post waits one second first so that posts get distinct
%% (second-resolution) timestamps.
load_test_data(File) ->
   {ok, Data} = file:consult(File),
   F = fun (User, {Us, Ps, Rej}) when is_record(User, user) ->
             case new_user(User) of
                {ok, _} -> {Us + 1, Ps, Rej};
                Err -> {Us, Ps, [{Err, User}|Rej]}
             end;
           (Post, {Us, Ps, Rej}) when is_record(Post, post) ->
             sleep(1000),
             case new_post(Post) of
                {ok, _} -> {Us, Ps + 1, Rej};
                Err -> {Us, Ps, [{Err, Post}|Rej]}
             end;
           (New = {new_post, _, _, Author,
                   {summary, ParSum}, Summary, Text}, {Us, Ps, Rej}) ->
             sleep(1000),
             try
                #post{id = ParentId} = find_post(#post{summary = ParSum}),
                {ok, _} = new_post(#post{author = Author,
                                         parent = ParentId,
                                         summary = Summary,
                                         text = Text}),
                {Us, Ps + 1, Rej}
             catch
                %% A bare `catch Err' only sees throw/1; the failures
                %% here (parent not found, insert failed) are errors.
                _Class:Err ->
                   {Us, Ps, [{Err, New}|Rej]}
             end;
           (Reject, {Us, Ps, Rej}) ->
             {Us, Ps, [Reject|Rej]}
       end,
   {Users, Posts, Rejects} = lists:foldl(F, {0, 0, []}, Data),
   [{users, Users}, {posts, Posts}, {rejects, lists:reverse(Rejects)}].
                     
%% @spec drop_tables() -> [Result]
%%    Result = {atomic, ok} + Error
%%    Error = any()
%% @doc Globally drop the application tables (post, then user) and
%% return one result per table.
%%
%% WARNING: Yes, you can do this.
drop_tables() ->
   [mnesia:delete_table(Tab) || Tab <- [post, user]].

%% EUnit generator exercising the module end to end against mnesia.
%% The tests share database state and rely on running in list order:
%% create_tables/1 first, then load_test_data/1 (reading
%% "example/test.txt" relative to the working directory), the query
%% checks, and finally drop_tables/0 to clean up.
%% NOTE(review): assumes mnesia is already running with a schema that
%% allows disc_copies on this node and that the post/user tables do not
%% exist yet — confirm in the test setup.
op_test_() ->
    [
     ?_assertMatch([{atomic, ok},
                    {atomic, ok},
                    {atomic, ok},
                    {atomic, ok},
                    {atomic, ok}],
                   create_tables(local)),
     %% load_test_data/1 sleeps one second before each post, hence the
     %% longer timeout.
     {timeout, 45,
      ?_assertMatch([{users, 4}, {posts, 5}, {rejects, []}],
                    load_test_data("example/test.txt"))},
     ?_assertMatch({ok, {_, _, _, _}},
                   authenticate("partdavid", "common")),
     ?_assertMatch({ok, {_, _, _, _}},
                   authenticate("partdavid@gmail.com", "common")),
     ?_assertMatch({nonexistent_user, "foo"},
                   authenticate("foo", "bar")),
     ?_assertMatch({wrong_password, {_, _, _, _}},
                   authenticate("partdavid", "foo")),
     ?_assertMatch([#post{summary = "Grate Shakes is a band!"}],
                   latest_posts(1)),
     ?_assertMatch([#post{summary = "Grate Shakes is a band!"},
                    #post{summary = "Re: Re: From Grate Shakes"}],
                   latest_posts(2)),
     ?_assertMatch(ok,
                   watch(#user{nickname = "partdavid"},
                         #post{summary = "Re: From Grate Shakes"})),
     ?_assertMatch([#post{summary = "Re: From Grate Shakes"}],
                   watching(#user{nickname = "partdavid"})),
     ?_assertMatch({#post{summary = "Welcome"}, []},
                   thread(#post{summary = "Welcome"})),
     ?_assertMatch({#post{summary = "Re: From Grate Shakes"},
                    [{#post{summary = "Re: Re: From Grate Shakes"},
                      []}]},
                   thread(#post{summary = "Re: From Grate Shakes"})),
     ?_assertMatch({#post{summary = "From Grate Shakes"},
                    [{#post{summary = "Re: From Grate Shakes"},
                      [{#post{summary = "Re: Re: From Grate Shakes"},
                        []}]},
                     {#post{summary = "Grate Shakes is a band!"}, []}]},
                   thread(#post{summary = "From Grate Shakes"})),
     ?_assertMatch([{atomic, ok}, {atomic, ok}],
                   drop_tables())
    ].
